[BugFix] Fixes register_save_hook bug #3340
vmoens
left a comment
Hey @ParamThakkar123 👋
Thank you so much for tackling this bug and additions!
This was a real gap in the API and your PR correctly identified the issue from #3247.
I've gone ahead and patched a few things in your implementation – my apologies for stepping on your code, but I wanted to make sure the fix was solid before merging. I hope you don't mind!
Changes I made:
- Moved the hook lists from class-level to instance-level in StorageCheckpointerBase.__init__() (the original class-level lists would have caused hooks to be shared across all checkpointer instances)
- Consolidated the _get_shift_from_last_cursor and shift/is_full computation into the base class to avoid duplication between TensorStorageCheckpointer and FlatStorageCheckpointer
- Simplified Storage.register_save_hook() to just forward to the checkpointer
- Updated the test to avoid local imports and use a proper class instead of mock.Mock()
The core idea of your fix was exactly right – we just needed to plug a few architectural details. Thanks again for the contribution!
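For readers following along, the class-level versus instance-level point in the first bullet can be illustrated with a minimal sketch. The two classes below are hypothetical stand-ins written for this illustration, not the actual torchrl checkpointer code; only the `_save_hooks` list placement matters.

```python
# Hypothetical toy classes (not torchrl code) showing why a mutable
# class-level list leaks hooks across instances.


class SharedHooksCheckpointer:
    # Pitfall: this list is created once, at class-definition time,
    # and every instance sees the same object.
    _save_hooks = []

    def register_save_hook(self, hook):
        self._save_hooks.append(hook)


class PerInstanceHooksCheckpointer:
    def __init__(self):
        # Each instance gets its own list, so hooks stay isolated.
        self._save_hooks = []

    def register_save_hook(self, hook):
        self._save_hooks.append(hook)


def noop_hook(data):
    return data


a, b = SharedHooksCheckpointer(), SharedHooksCheckpointer()
a.register_save_hook(noop_hook)
shared_leak = len(b._save_hooks)  # b also sees the hook registered on a

c, d = PerInstanceHooksCheckpointer(), PerInstanceHooksCheckpointer()
c.register_save_hook(noop_hook)
isolated = len(d._save_hooks)  # d is unaffected by c
```

With the class-level list, registering a hook on one checkpointer silently affects every other checkpointer of that class, which is the behavior the move into `__init__()` avoids.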
1b2c730 to ce6806d
Thank you so much @vmoens 🫡. Happy to contribute and work more on making this repo better.
Description
Describe your changes in detail.
Add register_save_hook and register_load_hook methods to Storage in checkpointer.py to fix attribute errors occurring while using LazyMemmapStorage.
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213
Fixes #3247
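As a rough illustration of the change described above, the sketch below shows a storage object that exposes register_save_hook and register_load_hook by forwarding to its checkpointer. `ToyStorage`, `ToyCheckpointer`, and the `dumps` method are hypothetical names invented for this example; they are assumptions and do not reproduce the actual torchrl signatures.

```python
# Hypothetical sketch (not the torchrl implementation): the storage
# forwards hook registration to the checkpointer that owns the hooks.


class ToyCheckpointer:
    def __init__(self):
        # Instance-level lists so hooks are not shared between checkpointers.
        self._save_hooks = []
        self._load_hooks = []

    def register_save_hook(self, hook):
        self._save_hooks.append(hook)

    def register_load_hook(self, hook):
        self._load_hooks.append(hook)

    def dumps(self, data):
        # Apply each save hook in registration order before "saving".
        for hook in self._save_hooks:
            data = hook(data)
        return data


class ToyStorage:
    def __init__(self, checkpointer=None):
        self.checkpointer = checkpointer if checkpointer is not None else ToyCheckpointer()

    def register_save_hook(self, hook):
        # Without this method, calling it on the storage raises AttributeError.
        self.checkpointer.register_save_hook(hook)

    def register_load_hook(self, hook):
        self.checkpointer.register_load_hook(hook)


storage = ToyStorage()
storage.register_save_hook(lambda data: {**data, "tagged": True})
result = storage.checkpointer.dumps({"obs": 1})
```

Keeping the hook lists on the checkpointer and having the storage only forward means there is a single place where hooks are stored and applied.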
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!